import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

import torch.optim as optim
from torch.utils.data import DataLoader
from scipy.stats import kendalltau


# from exps.predictors.utils.encodings import encode
from exps.nasbenchs.natsbench import Natsbench
from exps.predictors.predictor import Predictor
from exps.predictors.utils import loguniform
from exps.predictors.utils import AverageMeterGroup

# The predictor below is adapted from the TNASP model (a Transformer-based NAS performance predictor).

# Run on the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print('device:', device)


def to_cuda(obj, _device):
    """Recursively move every tensor contained in ``obj`` onto ``_device``.

    Tuples, lists and dicts are rebuilt with each element converted; plain
    ``int``/``float``/``str`` leaves are returned unchanged.

    Raises:
        ValueError: if ``obj`` (or any nested element) has another type.
    """
    if torch.is_tensor(obj):
        return obj.to(_device)
    if isinstance(obj, (int, float, str)):
        return obj
    if isinstance(obj, dict):
        return {key: to_cuda(value, _device) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        moved = [to_cuda(item, _device) for item in obj]
        # Tuples (including namedtuples) come back as plain tuples.
        return tuple(moved) if isinstance(obj, tuple) else moved
    raise ValueError("'%s' has unsupported type '%s'" % (obj, type(obj)))


def normalize_adj(adj):
    """Row-normalize ``adj`` so that every non-zero row sums to one.

    The last dimension indexes columns, so this works for a single ``(N, N)``
    matrix as well as a batched ``(B, N, N)`` stack. Rows that sum to zero are
    returned as all zeros instead of producing NaN from ``0 / 0``.
    """
    rowsum = adj.sum(-1, keepdim=True)
    # Isolated nodes have a zero row sum; dividing that zero row by 1 keeps it zero.
    rowsum = rowsum.masked_fill(rowsum == 0, 1)
    # Broadcasting over the last axis replaces an explicit repeat.
    return adj / rowsum


def graph_pooling(inputs, num_vertices):
    """Average node features per graph.

    Sums ``inputs`` over dim 1 and divides by ``num_vertices`` (with a trailing
    axis added and expanded to the summed shape). Presumably ``inputs`` is
    ``(batch, nodes, features)`` and ``num_vertices`` is ``(batch,)`` — confirm.
    """
    totals = torch.sum(inputs, dim=1)
    counts = torch.unsqueeze(num_vertices, -1).expand_as(totals)
    return totals / counts


def accuracy_mse(prediction, target, scale=100.0):
    """Mean squared error after multiplying both tensors by ``scale``.

    ``prediction`` is detached first, so no gradient flows back into it.
    """
    scaled_pred = scale * prediction.detach()
    scaled_target = scale * target
    return F.mse_loss(scaled_pred, scaled_target)


class ScaledDotProductAttention(nn.Module):
    """softmax(Q K^T / temperature) V attention.

    ``k`` is transposed on dims 2 and 3, so 4-D inputs are expected (e.g.
    ``(batch, heads, length, dim)``). Returns the attended values and the
    dropout-applied attention weights.
    """

    def __init__(self, temperature, attn_dropout=0.1):
        super().__init__()
        self.temperature = temperature
        self.dropout = nn.Dropout(attn_dropout)

    def forward(self, q, k, v, mask=None):
        scores = (q / self.temperature) @ k.transpose(2, 3)
        if mask is not None:
            # Masked-out positions (mask == 0) get a huge negative score, so
            # their softmax weight is effectively zero.
            scores = scores.masked_fill(mask == 0, -1e9)
        weights = F.softmax(scores, dim=-1)
        weights = self.dropout(weights)
        return weights @ v, weights


class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to an input of shape ``(batch, length, d_hid)``.

    The table is kept in the non-trainable buffer ``pos_table`` of shape
    ``(1, n_position, d_hid)``. An input longer than ``n_position`` fails to
    broadcast against it.
    """

    def __init__(self, d_hid, n_position=200):
        super(PositionalEncoding, self).__init__()

        # A buffer rather than a parameter: it follows .to() and is saved in
        # state_dict, but is never trained.
        self.register_buffer('pos_table', self._get_sinusoid_encoding_table(n_position, d_hid))

    def _get_sinusoid_encoding_table(self, n_position, d_hid):
        """Build the float32 sinusoid table of shape ``(1, n_position, d_hid)``.

        Entry ``(pos, j)`` is ``pos / 10000 ** (2 * (j // 2) / d_hid)``, passed
        through ``sin`` for even ``j`` and ``cos`` for odd ``j``. The table is
        computed in float64 with torch and then cast to float32.
        """
        positions = torch.arange(n_position, dtype=torch.float64).unsqueeze(1)
        dims = torch.arange(d_hid, dtype=torch.float64)
        # 2 * (j // 2) gives each (sin, cos) column pair the same frequency.
        exponents = 2 * torch.floor(dims / 2) / d_hid
        table = positions / torch.pow(10000.0, exponents)
        table[:, 0::2] = torch.sin(table[:, 0::2])  # dim 2i
        table[:, 1::2] = torch.cos(table[:, 1::2])  # dim 2i+1
        return table.float().unsqueeze(0)  # (1,N,d)

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(1)`` positions."""
        # The buffer never requires grad, so slicing it directly is safe.
        return x + self.pos_table[:, :x.size(1)]


class MultiHeadAttention(nn.Module):
    """Multi-head attention with a residual connection and post-LayerNorm.

    Queries, keys and values are projected by bias-free linear layers, split
    into ``n_head`` heads, attended with scaled dot-product attention, merged,
    projected back to ``d_model``, passed through dropout and added to ``q``.
    """

    def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
        super().__init__()

        self.n_head = n_head
        self.d_k = d_k
        self.d_v = d_v

        # All heads' projections are packed into one Linear each.
        self.w_qs = nn.Linear(d_model, n_head * d_k, bias=False)
        self.w_ks = nn.Linear(d_model, n_head * d_k, bias=False)
        self.w_vs = nn.Linear(d_model, n_head * d_v, bias=False)
        self.fc = nn.Linear(n_head * d_v, d_model, bias=False)

        self.attention = ScaledDotProductAttention(temperature=d_k ** 0.5)

        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)

    def forward(self, q, k, v, mask=None):
        """Attend from ``q`` to ``k``/``v`` and return ``(output, attention_weights)``.

        Inputs are viewed as ``(batch, length, n_head * head_dim)`` after
        projection. ``output`` has the same shape as ``q``; the attention
        weights are ``(batch, n_head, len_q, len_k)``.
        """
        batch = q.size(0)
        heads = self.n_head
        shortcut = q

        def split_heads(x, proj, head_dim):
            # (batch, len, heads * head_dim) -> (batch, heads, len, head_dim)
            return proj(x).view(batch, x.size(1), heads, head_dim).transpose(1, 2)

        q_heads = split_heads(q, self.w_qs, self.d_k)
        k_heads = split_heads(k, self.w_ks, self.d_k)
        v_heads = split_heads(v, self.w_vs, self.d_v)

        # Insert a head axis so a per-batch mask broadcasts across all heads.
        if mask is not None:
            mask = mask.unsqueeze(1)

        context, attn = self.attention(q_heads, k_heads, v_heads, mask=mask)

        # (batch, heads, len_q, d_v) -> (batch, len_q, heads * d_v)
        context = context.transpose(1, 2).contiguous().view(batch, q.size(1), -1)

        out = self.dropout(self.fc(context))
        out += shortcut
        return self.layer_norm(out), attn


class PositionwiseFeedForward(nn.Module):
    """Position-wise two-layer ReLU MLP with a residual connection and post-LayerNorm."""

    def __init__(self, d_in, d_hid, dropout=0.1):
        super().__init__()
        # Submodules are registered in the same order (w_1, w_2, layer_norm,
        # dropout) so seeded initialisation and state_dict keys are unchanged.
        self.w_1 = nn.Linear(d_in, d_hid)
        self.w_2 = nn.Linear(d_hid, d_in)
        self.layer_norm = nn.LayerNorm(d_in, eps=1e-6)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        hidden = F.relu(self.w_1(x))
        out = self.dropout(self.w_2(hidden))
        out += x
        return self.layer_norm(out)


class EncoderLayer(nn.Module):
    """One Transformer encoder layer: multi-head self-attention, then a feed-forward block."""

    def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1):
        super().__init__()
        self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
        self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout)

    def forward(self, enc_input, slf_attn_mask=None):
        """Return ``(layer_output, self_attention_weights)`` for ``enc_input``."""
        attended, attn_weights = self.slf_attn(
            enc_input, enc_input, enc_input, mask=slf_attn_mask)
        return self.pos_ffn(attended), attn_weights


class Encoder(nn.Module):
    """Transformer encoder over operation tokens plus a bench-specific positional signal.

    ``bench`` selects how ``pos_seq`` is turned into positional features that
    are added to the token embeddings:

    * ``'nasbench101'``: ``pos_seq`` is projected from ``pos_enc_dim`` to
      ``d_word_vec``.
    * ``'nasbench201'``: ``pos_seq`` is mapped along its last axis to
      ``n_src_vocab + 1`` entries, transposed, then projected to ``d_word_vec``.
    * ``'darts_ss'``: ``pos_seq`` is transposed, mapped from ``2 * pos_enc_dim``
      to ``4 * pos_enc_dim`` entries, transposed back, then projected.
    * ``'mobilenet_ss'``: ``pos_seq`` is ignored and fixed sinusoidal encodings
      (21 positions) are added instead.

    NOTE(review): the default ``bench='101'`` matches none of these values, so
    ``forward`` raises ``ValueError`` unless a supported bench is passed.
    """

    def __init__(
            self, n_src_vocab, d_word_vec, n_layers, n_head, d_k, d_v,
            d_model, d_inner, pad_idx, pos_enc_dim=7, dropout=0.1, n_position=200, bench='101'):

        super().__init__()

        self.src_word_emb = nn.Embedding(n_src_vocab, d_word_vec, padding_idx=pad_idx)
        self.bench = bench
        if self.bench == 'nasbench101':
            self.embedding_lap_pos_enc = nn.Linear(pos_enc_dim, d_word_vec)
        elif self.bench == 'nasbench201':
            self.pos_map = nn.Linear(pos_enc_dim, n_src_vocab+1)
            self.embedding_lap_pos_enc = nn.Linear(pos_enc_dim, d_word_vec)
        elif self.bench == 'darts_ss':
            self.pos_map = nn.Linear(pos_enc_dim*2, pos_enc_dim*4)
            self.embedding_lap_pos_enc = nn.Linear(pos_enc_dim, d_word_vec)
        elif self.bench == 'mobilenet_ss':
            self.pos_map = PositionalEncoding(d_word_vec, n_position=21)

        self.dropout = nn.Dropout(p=dropout)
        self.layer_stack = nn.ModuleList([
            EncoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
            for _ in range(n_layers)])
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)

    def forward(self, src_seq, pos_seq, random_pos=None, src_mask=None, return_attns=False):
        """Encode ``src_seq``.

        Args:
            src_seq: integer token ids fed to ``nn.Embedding``.
            pos_seq: positional input for the selected bench (ignored for
                ``'mobilenet_ss'``).
            random_pos: unused; kept for interface compatibility.
            src_mask: optional attention mask forwarded to every layer.
            return_attns: also return the per-layer attention weights.

        Returns:
            The encoded sequence, or ``(encoded, attention_list)`` when
            ``return_attns`` is true.

        Raises:
            ValueError: if ``self.bench`` is not a supported value.
        """
        enc_slf_attn_list = []

        # Embed once and reuse the result in every branch.
        enc_output = self.src_word_emb(src_seq)
        if self.bench == 'mobilenet_ss':
            # Sinusoidal positions only; this branch applies no input dropout.
            enc_output = self.pos_map(enc_output)
        else:
            if self.bench == 'nasbench101':
                pos_output = self.embedding_lap_pos_enc(pos_seq)
            elif self.bench == 'nasbench201':
                pos_output = self.embedding_lap_pos_enc(self.pos_map(pos_seq).transpose(1, 2))
            elif self.bench == 'darts_ss':
                pos_output = self.embedding_lap_pos_enc(
                    self.pos_map(pos_seq.transpose(1, 2)).transpose(1, 2))
            else:
                raise ValueError('Un-defined bench!')
            enc_output += pos_output
            enc_output = self.dropout(enc_output)

        enc_output = self.layer_norm(enc_output)

        for enc_layer in self.layer_stack:
            enc_output, enc_slf_attn = enc_layer(enc_output, slf_attn_mask=src_mask)
            if return_attns:
                enc_slf_attn_list.append(enc_slf_attn)

        if return_attns:
            return enc_output, enc_slf_attn_list
        return enc_output


def get_laplacian_matrix(adj):
    """Return the unnormalized graph Laplacian ``D - A``.

    ``D`` is the diagonal matrix of the row sums of ``adj``. Works on a single
    ``(N, N)`` matrix or on any batch of them (the last two axes are the
    matrix).

    The result dtype is ``adj``'s dtype promoted with the default floating
    dtype, so integer or boolean adjacency matrices yield float Laplacians,
    and float64 input keeps full float64 precision in the degrees.
    """
    out_dtype = torch.promote_types(torch.get_default_dtype(), adj.dtype)
    adj = adj.to(out_dtype)
    # diag_embed builds all degree matrices at once instead of a Python loop.
    degree = torch.diag_embed(adj.sum(-1))
    return degree - adj


def get_features_index(out):
    """Map each one-hot vector in ``out`` to the index of its first entry equal to 1.

    Args:
        out: tensor whose last dimension holds the one-hot vectors, e.g.
            ``(batch, nodes, num_ops)``.

    Returns:
        An int64 tensor of shape ``out.shape[:-1]`` on ``out``'s device.

    Raises:
        IndexError: if some vector contains no entry equal to 1.
    """
    hits = out == 1
    # One reduction and a single host sync replace a per-element .cpu() loop.
    if not bool(hits.any(dim=-1).all()):
        raise IndexError("get_features_index: found a vector with no entry equal to 1")
    # argmax returns the first maximal position, i.e. the first 1 in each vector.
    return hits.long().argmax(dim=-1)


class ClassicalTransformer(nn.Module):
    """Transformer-based accuracy predictor for cell-based architectures.

    For graph benches, per-node operation indices (recovered from the one-hot
    ``"operations_oneshot"`` input) are embedded and combined with a matrix
    derived from the adjacency (chosen by ``adj_type``). The encoder output is
    mean-pooled over nodes and regressed to one scalar per graph by a
    two-layer bias-free MLP.
    """

    def __init__(self, adj_type, n_src_vocab, d_word_vec, n_layers, n_head, d_k, d_v,
                 d_model, d_inner, pad_idx=None, pos_enc_dim=7, dropout=0.1, linear_hidden=80, bench='nasbench101'):
        super(ClassicalTransformer, self).__init__()

        self.bench = bench
        self.adj_type = adj_type
        # Use the caller's dropout rate for both the encoder and the head
        # (it used to be hard-coded to 0.1, silently ignoring the argument).
        self.encoder = Encoder(n_src_vocab, d_word_vec, n_layers, n_head, d_k, d_v,
                               d_model, d_inner, pad_idx, pos_enc_dim=pos_enc_dim, dropout=dropout, bench=bench)

        self.dropout = nn.Dropout(dropout)
        self.fc1 = nn.Linear(d_word_vec, linear_hidden, bias=False)
        self.fc2 = nn.Linear(linear_hidden, 1, bias=False)

    def forward(self, inputs):
        """Predict one scalar per graph in the batch dict ``inputs``.

        Keys read: ``"num_vertices"`` always; ``"features"`` when
        ``bench == 'mobilenet_ss'``; otherwise ``"adjacency"`` and
        ``"operations_oneshot"``, plus ``"lapla"`` / ``"lapla_nor"`` for
        ``adj_type`` ``'adj_lapla'`` / ``'adj_lapla_nor'``.

        Returns:
            A 1-D tensor with one prediction per graph.

        Raises:
            ValueError: for an unknown ``adj_type``.
        """
        numv = inputs["num_vertices"]
        if self.bench == 'mobilenet_ss':
            out = self.encoder(inputs["features"], None)
        else:
            adj_matrix = inputs["adjacency"]

            if self.adj_type == 'adj':
                pass
            elif self.adj_type in ('adj_nor', 'nor_lapla'):
                # Add self-loops, then row-normalize.
                gs = adj_matrix.size(1)  # graph node number
                adj_matrix = normalize_adj(adj_matrix + torch.eye(gs, device=adj_matrix.device))
                if self.adj_type == 'nor_lapla':
                    adj_matrix = get_laplacian_matrix(adj_matrix)
            elif self.adj_type == 'adj_lapla':
                adj_matrix = inputs["lapla"].float()
            elif self.adj_type == 'adj_lapla_nor':
                adj_matrix = inputs["lapla_nor"].float()
            elif self.adj_type == 'lapla':
                adj_matrix = get_laplacian_matrix(adj_matrix)
            else:
                raise ValueError('No Defined ADJ Type!')

            # Recover per-node operation indices from the one-hot encoding.
            ops = get_features_index(inputs["operations_oneshot"].long())
            out = self.encoder(src_seq=ops, pos_seq=adj_matrix.float())

        out = graph_pooling(out, numv)
        out = self.fc1(out)
        out = self.dropout(out)
        # No sigmoid on the output: per the original author's note, adding one performed poorly.
        return self.fc2(out).view(-1)

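'''
Debug output kept for reference: presumably ``out`` as returned by
``get_features_index`` in ``ClassicalTransformer.forward`` (10 graphs x 8 nodes,
one operation index per node).

out: tensor([[1, 4, 5, 5, 2, 2, 2, 0],
        [1, 2, 5, 6, 5, 5, 6, 0],
        [1, 6, 2, 5, 6, 3, 4, 0],
        [1, 6, 6, 5, 5, 5, 6, 0],
        [1, 2, 5, 6, 3, 6, 5, 0],
        [1, 2, 3, 5, 5, 6, 5, 0],
        [1, 6, 2, 6, 3, 6, 2, 0],
        [1, 3, 5, 4, 6, 3, 5, 0],
        [1, 2, 4, 2, 3, 4, 6, 0],
        [1, 6, 6, 4, 4, 4, 6, 0]], device='cuda:0')

'''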

class TransformerPredictor(object):
    """Train and query a ``ClassicalTransformer`` on pre-encoded samples.

    Each sample is a dict holding the inputs ``ClassicalTransformer.forward``
    reads plus a ``"val_acc"`` target. Targets are z-score normalized in
    ``fit``; ``predict`` maps predictions back to the original scale.
    """

    def __init__(self, adj_type, n_src_vocab, pos_enc_dim, bench):

        self.predictor = ClassicalTransformer(
            adj_type=adj_type,
            n_src_vocab=n_src_vocab,
            pos_enc_dim=pos_enc_dim,
            d_word_vec=80,
            n_layers=3,
            n_head=4,
            d_k=64,
            d_v=64,
            d_model=80,
            d_inner=512,
            linear_hidden=96,
            dropout=0.3,
            bench=bench
        )

        # Target normalization statistics; set by fit().
        self.mean = 0.0
        self.std = 0.0

    def fit(self,
            train_data,
            batch_size=10,
            epochs=300,
            lr=1e-4,
            wd=1e-3,
            ):
        """Train ``self.predictor`` with Adam and cosine-annealed learning rate.

        ``train_data`` is left unmodified: the normalized targets are written
        into shallow copies of each sample, so calling ``fit`` again does not
        re-normalize already-normalized targets.

        Raises:
            ValueError: if ``train_data`` has fewer than ``batch_size`` samples
                (with ``drop_last=True`` no batch would ever be produced).
        """
        if len(train_data) < batch_size:
            raise ValueError(
                "fit needs at least batch_size={} samples, got {}".format(batch_size, len(train_data)))

        ytrain = np.asarray([sample["val_acc"] for sample in train_data], dtype=np.float64)
        self.mean = np.mean(ytrain)
        self.std = np.std(ytrain)
        if self.std == 0:
            # Constant targets: avoid 0/0, which would turn every target into NaN.
            self.std = 1.0
        ytrain_normed = (ytrain - self.mean) / self.std

        normed_data = []
        for sample, y in zip(train_data, ytrain_normed):
            sample = dict(sample)
            # float32 to match the model output; a float64 target (what an
            # np.float64 scalar collates to) can break MSELoss's backward.
            sample["val_acc"] = torch.tensor(y, dtype=torch.float32)
            normed_data.append(sample)

        data_loader = DataLoader(normed_data, batch_size=batch_size, shuffle=True, drop_last=True)

        self.predictor.to(device)
        criterion = nn.MSELoss()
        optimizer = optim.Adam(self.predictor.parameters(), lr=lr, weight_decay=wd)
        lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)
        self.predictor.train()

        for epoch in range(epochs):
            for batch in data_loader:
                batch = to_cuda(batch, device)
                prediction = self.predictor(batch)
                target = batch["val_acc"].to(prediction.dtype)
                optimizer.zero_grad()
                loss = criterion(prediction, target)
                loss.backward()
                optimizer.step()

            # Reports the loss of the epoch's last batch.
            print("epoch: {}, loss: {}".format(epoch, loss.item()))
            lr_scheduler.step()

    def predict(self, test_data):
        """Return de-normalized predictions for ``test_data``.

        A single-sample ``test_data`` yields a scalar; several samples yield a
        numpy array with one prediction per sample, in order.

        Raises:
            ValueError: if ``test_data`` is empty.
        """
        self.predictor.to(device)
        self.predictor.eval()

        data_loader = DataLoader(test_data)
        preds = []
        with torch.no_grad():
            for batch in data_loader:
                batch = to_cuda(batch, device)
                preds.append(self.predictor(batch).view(-1).cpu())

        if not preds:
            raise ValueError("predict needs at least one sample")
        pred = torch.cat(preds)
        if pred.numel() == 1:
            return pred.item() * self.std + self.mean
        return pred.numpy() * self.std + self.mean

    # def get_test_loss(self, test_data, eval_batch_size=1000, log=None):
    #     test_data_loader = DataLoader(test_data, batch_size=eval_batch_size)
    #     criterion = nn.MSELoss()
    #
    #     self.predictor.eval()
    #
    #     predict_, target_ = [], []
    #     with torch.no_grad():
    #         for step, batch in enumerate(test_data_loader):
    #             batch = to_cuda_float32(batch)
    #             target = batch["val_acc"]
    #             pred_val = self.predictor(batch)
    #             predict_.append(pred_val.cpu().numpy())
    #             target_.append(target.cpu().numpy())
    #
    #             # predict_.append(predict.numpy())
    #             # target_.append(target.numpy())
    #             # meters.update({"loss": criterion(predict, target).item(),
    #             #                "mse": accuracy_mse(predict, target).item()}, n=target.size(0))
    #
    #             # if (args.eval_print_freq and step % args.eval_print_freq == 0) or \
    #             #         step % 10 == 0 or step + 1 == len(test_data_loader):
    #             #     logger.info("Evaluation Step [%d/%d]  %s", step + 1, len(test_data_loader), meters)
    #
    #     predict_ = np.concatenate(predict_)
    #     target_ = np.concatenate(target_)
    #     print("Func-get_test_loss:", "predict_:", predict_[:10], "target_:", target_[:10])
    #     # logger.info("Kendalltau: %.6f", kendalltau(predict_, target_)[0])
    #
    #     result = kendalltau(predict_, target_)[0]
    #     print("Kendalltau: {:.6f}".format(result))
    #     return result


class TransformerPredictorV2(Predictor):
    """``Predictor`` wrapper that trains a ``ClassicalTransformer`` from raw architectures.

    Architectures are encoded with ``bench_api.dag_encoding(arch, dataset=self.dataset)``.
    Targets are z-score normalized for training and mapped back in ``query``.
    """

    def __init__(self, encoding_type="transformer", ss_type=None, hpo_wrapper=False, dataset=None, bench_api=None):
        self.encoding_type = encoding_type
        # NOTE(review): ss_type is stored only when given, yet get_model() reads
        # self.ss_type; presumably the Predictor base class supplies it otherwise — confirm.
        if ss_type is not None:
            self.ss_type = ss_type
        self.hpo_wrapper = hpo_wrapper
        self.default_hyperparams = {
            "batch_size": 10,
            "lr": 1e-4,
            "wd": 1e-3,
        }
        self.hyperparams = None
        self.dataset = dataset

        self.bench_api = bench_api

    def get_model(self, **kwargs):
        """Build an untrained ``ClassicalTransformer`` sized for ``self.ss_type``.

        ``kwargs`` is accepted but ignored.

        Raises:
            Exception: if ``self.ss_type`` is not ``"nasbench101"`` or ``"nasbench201"``.
        """
        if self.ss_type == "nasbench101":
            n_src_vocab = 5
            pos_enc_dim = 7
        elif self.ss_type == "nasbench201":
            n_src_vocab = 7
            pos_enc_dim = 8
        else:
            raise Exception("Invalid value:", self.ss_type)

        return ClassicalTransformer(
            adj_type='lapla',
            n_src_vocab=n_src_vocab,
            pos_enc_dim=pos_enc_dim,
            d_word_vec=80,
            n_layers=3,
            n_head=4,
            d_k=64,
            d_v=64,
            d_model=80,
            d_inner=512,
            linear_hidden=96,
            dropout=0.3,
            bench=self.ss_type,
        )

    def fit(self, xtrain, ytrain, train_info=None, epochs=100):
        """Train a fresh model on ``xtrain``/``ytrain`` and return its mean absolute train error.

        ``train_info`` is unused. Hyperparameters come from ``self.hyperparams``;
        the defaults are installed when none have been set.
        """
        if self.hyperparams is None:
            self.hyperparams = self.default_hyperparams.copy()

        batch_size = self.hyperparams["batch_size"]
        lr = self.hyperparams["lr"]
        wd = self.hyperparams["wd"]

        # Normalize targets. Constant targets would give std == 0 and NaN
        # targets, so fall back to a unit scale in that case.
        self.mean = np.mean(ytrain)
        self.std = np.std(ytrain)
        if self.std == 0:
            self.std = 1.0
        ytrain_normed = (np.asarray(ytrain) - self.mean) / self.std

        train_data = []
        for i, arch in enumerate(xtrain):
            encoded = self.bench_api.dag_encoding(arch, dataset=self.dataset)
            # float32 to match the model's output dtype in MSELoss.
            encoded["val_acc"] = torch.tensor(ytrain_normed[i], dtype=torch.float32)
            train_data.append(encoded)

        data_loader = DataLoader(
            train_data, batch_size=batch_size, shuffle=True, drop_last=True
        )

        self.model = self.get_model()
        self.model.to(device)
        criterion = nn.MSELoss()
        optimizer = optim.Adam(self.model.parameters(), lr=lr, weight_decay=wd)
        lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)

        self.model.train()
        for _ in range(epochs):
            for batch in data_loader:
                batch = to_cuda(batch, device)
                target = batch["val_acc"]
                prediction = self.model(batch)
                optimizer.zero_grad()
                loss = criterion(prediction, target)
                loss.backward()
                optimizer.step()
            lr_scheduler.step()

        train_pred = np.squeeze(self.query(xtrain))
        train_error = np.mean(abs(train_pred - ytrain))
        return train_error

    def query(self, xtest, info=None, eval_batch_size=1000):
        """Return de-normalized predictions (a numpy array) for ``xtest``.

        ``info`` is unused. An empty ``xtest`` yields an empty array.
        """
        test_data = [
            self.bench_api.dag_encoding(arch, dataset=self.dataset)
            for arch in xtest
        ]
        if not test_data:
            # np.concatenate below would fail on an empty list.
            return np.empty(0)
        test_data_loader = DataLoader(test_data, batch_size=eval_batch_size)

        self.model.eval()
        pred = []
        with torch.no_grad():
            for batch in test_data_loader:
                batch = to_cuda(batch, device)
                pred.append(self.model(batch).cpu().numpy())

        return np.concatenate(pred) * self.std + self.mean

    def set_random_hyperparams(self):
        """Choose and store hyperparameters for the next HPO trial.

        While ``self.hyperparams`` is None the defaults are used; afterwards
        batch size, learning rate and weight decay are sampled via ``loguniform``.
        """
        if self.hyperparams is None:
            params = self.default_hyperparams.copy()
        else:
            params = {
                "batch_size": int(loguniform(5, 32)),
                "lr": loguniform(0.00001, 0.1),
                "wd": loguniform(0.00001, 0.1),
            }

        self.hyperparams = params
        return params


if __name__ == '__main__':
    # Smoke test: build a model with the NAS-Bench-101 vocabulary sizes used
    # by TransformerPredictorV2.get_model and report its parameter count.
    net = ClassicalTransformer(
        adj_type='adj',
        n_src_vocab=5,
        pos_enc_dim=7,
        d_word_vec=80,
        n_layers=3,
        n_head=4,
        d_k=64,
        d_v=64,
        d_model=80,
        d_inner=512,
        linear_hidden=96,
        dropout=0.3,
    )

    n_params = sum(p.numel() for p in net.parameters())
    print('# Model Parameters: %.3f M' % (n_params / 1e6))
